import copy
import json

import torch
from torch.utils.data import Dataset
import datasets
import pandas as pd
import os
import numpy as np


# Split names that have a ``<data_dir>/geo880_splits/<name>/split.json`` file.
_GEOQUERY_SPLIT_METHODS = (
    "standard",
    "length_1",
    "template_1",
    "template_2",
    "template_3",
    "tmcd_1",
    "tmcd_2",
    "tmcd_3",
)


def get_preprocessed_geoquery(split_method, max_words=256, data_dir="data/geo880"):
    """Load the GeoQuery (geo880) train/test datasets for one split.

    Reads every example from ``<data_dir>/geo880_v2_all.jsonl`` and partitions
    the rows by their ``qid`` column according to the ``train`` / ``test`` id
    lists stored in ``<data_dir>/geo880_splits/<split_method>/split.json``.

    Args:
        split_method: Name of the split; one of ``_GEOQUERY_SPLIT_METHODS``.
        max_words: Currently unused by this function; kept so existing callers
            that pass it keep working.
        data_dir: Directory containing the geo880 data. Defaults to the
            relative path ``data/geo880`` (resolved against the current
            working directory).

    Returns:
        A ``(dataset_train, dataset_test)`` tuple of ``datasets.Dataset``
        objects holding the selected rows of the jsonl file.

    Raises:
        NotImplementedError: If ``split_method`` is not a known split.
        FileNotFoundError: If the split file or the jsonl file is missing.
    """
    # Validate before touching the filesystem so a typo fails immediately
    # instead of after parsing the whole jsonl file.
    if split_method not in _GEOQUERY_SPLIT_METHODS:
        raise NotImplementedError(
            f"Unknown GeoQuery split method {split_method!r}; "
            f"expected one of {_GEOQUERY_SPLIT_METHODS}"
        )

    split_file = os.path.join(data_dir, "geo880_splits", split_method, "split.json")
    with open(split_file, "r", encoding="utf-8") as f:
        split_data = json.load(f)

    df = pd.read_json(os.path.join(data_dir, "geo880_v2_all.jsonl"), lines=True)
    df_train = df[df["qid"].isin(split_data["train"])]
    df_test = df[df["qid"].isin(split_data["test"])]

    # Boolean filtering leaves a non-RangeIndex behind; by default
    # from_pandas would store it as an extra "__index_level_0__" column.
    dataset_train = datasets.Dataset.from_pandas(df_train, preserve_index=False)
    dataset_test = datasets.Dataset.from_pandas(df_test, preserve_index=False)

    return dataset_train, dataset_test